{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "latest version of fastai now installed\r\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.tracker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.callback.progress import *\n",
    "from fastai.callback.fp16 import MixedPrecision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tracking callbacks\n",
    "\n",
    "> Callbacks that make decisions depending on how a monitored metric/loss behaves"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TerminateOnNaNCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class TerminateOnNaNCallback(Callback):\n",
    "    \"A `Callback` that cancels the fit as soon as the loss becomes infinite or NaN.\"\n",
    "    # Runs before the `Recorder` (see `run_before`)\n",
    "    run_before=Recorder\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Check the loss of the batch that just ran and interrupt training if it is not finite.\"\n",
    "        loss_diverged = torch.isinf(self.loss) or torch.isnan(self.loss)\n",
    "        # `CancelFitException` is the signal used to interrupt training\n",
    "        if loss_diverged: raise CancelFitException"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>215824033216875417685323736416256.000000</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# A deliberately huge learning rate, so that the loss diverges and the callback interrupts the 10-epoch fit\n",
    "learn = synth_learner()\n",
    "learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fewer losses than 10 full epochs' worth were recorded, i.e. training was cut short...\n",
    "assert len(learn.recorder.losses) < 10 * len(learn.dls.train)\n",
    "# ...and none of the losses that were recorded is infinite or NaN\n",
    "for l in learn.recorder.losses:\n",
    "    assert not torch.isinf(l) and not torch.isnan(l) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TrackerCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class TrackerCallback(Callback):\n",
    "    \"A `Callback` that keeps track of the best value in `monitor`.\"\n",
    "    remove_on_fetch,run_after = True,Recorder\n",
    "\n",
    "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., reset_on_fit=True):\n",
    "        # `monitor`: name of the recorded loss/metric to follow\n",
    "        # `comp`: `comp(new, best)` is truthy when `new` beats `best`; when `None`, `np.less` is used for names\n",
    "        #         containing 'loss' or 'error' and `np.greater` for everything else\n",
    "        # `min_delta`: margin by which a value has to beat the current best to count as an improvement\n",
    "        # `reset_on_fit`: forget `best` at the start of every fit when `True`, keep it across fits when `False`\n",
    "        if comp is None: comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater\n",
    "        # When lower is better the margin is added to the new value instead of subtracted, hence the sign flip\n",
    "        if self._lower_is_better(comp): min_delta *= -1\n",
    "        self.monitor,self.comp,self.min_delta,self.reset_on_fit,self.best= monitor,comp,min_delta,reset_on_fit,None\n",
    "\n",
    "    @staticmethod\n",
    "    def _lower_is_better(comp):\n",
    "        \"Probe `comp` rather than testing `comp == np.less`, so that e.g. `operator.lt` is treated like `np.less`\"\n",
    "        return bool(comp(0., 1.))\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Prepare the monitored value\"\n",
    "        # Disabled while an `lr_finder` or `gather_preds` attribute is reachable from this callback\n",
    "        self.run = not hasattr(self, \"lr_finder\") and not hasattr(self, \"gather_preds\")\n",
    "        # Start from the worst possible value so that any finite first value becomes the best\n",
    "        if self.reset_on_fit or self.best is None: self.best = float('inf') if self._lower_is_better(self.comp) else -float('inf')\n",
    "        assert self.monitor in self.recorder.metric_names[1:], f'`{self.monitor}` is not one of the recorded values: {list(self.recorder.metric_names[1:])}'\n",
    "        # Position of `monitor` in a row of `recorder.values` (the first metric name is skipped)\n",
    "        self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Compare the last value to the best up to now\"\n",
    "        val = self.recorder.values[-1][self.idx]\n",
    "        # `new_best` only exists once this method has run at least once\n",
    "        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True\n",
    "        else: self.new_best = False\n",
    "\n",
    "    def after_fit(self): self.run=True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. If you want to maintain `best` over subsequent calls to `fit` (e.g., `Learner.fit_one_cycle`), set `reset_on_fit` = False.\n",
    "\n",
    "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by at least that amount."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class FakeRecords(Callback):\n",
    "    \"Test helper that overwrites the recorded value of `monitor` with `values[epoch]` after each epoch\"\n",
    "    run_after,run_before = Recorder,TrackerCallback\n",
    "\n",
    "    def __init__(self, monitor, values):\n",
    "        self.monitor = monitor\n",
    "        self.values = values\n",
    "\n",
    "    def before_fit(self):\n",
    "        # Position of `monitor` in a row of `recorder.values` (the first metric name is skipped)\n",
    "        names = list(self.recorder.metric_names[1:])\n",
    "        self.idx = names.index(self.monitor)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        last_row = self.recorder.values[-1]\n",
    "        last_row[self.idx] = self.values[self.epoch]\n",
    "        \n",
    "class TestTracker(Callback):\n",
    "    \"Test helper that logs the tracker's `best` and `new_best` at the end of every epoch\"\n",
    "    run_after=TrackerCallback\n",
    "\n",
    "    def before_fit(self):\n",
    "        self.bests = []\n",
    "        self.news = []\n",
    "\n",
    "    def after_epoch(self):\n",
    "        tracker = self.tracker\n",
    "        self.bests.append(tracker.best)\n",
    "        self.news.append(tracker.new_best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "#`valid_loss` is compared with `np.less` by default, so going from 0.2 to 0.1 is a new best at both epochs\n",
    "learn = synth_learner(n_trn=2, cbs=TestTracker())\n",
    "cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]\n",
    "with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
    "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n",
    "test_eq(learn.test_tracker.news,  [True,True])\n",
    "\n",
    "#With a min_delta of 0.15, an improvement of only 0.1 does not count as a new best\n",
    "cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]\n",
    "with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
    "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n",
    "test_eq(learn.test_tracker.news,  [True,False])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "#By default a metric (no 'loss' or 'error' in its name) has to get bigger to count as an improvement.\n",
    "def tst_metric(out,targ): return F.mse_loss(out,targ)\n",
    "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n",
    "cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]\n",
    "with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
    "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n",
    "test_eq(learn.test_tracker.news,  [True,False])\n",
    "\n",
    "#This can be overwritten by passing `comp=np.less`, in which case the drop from 0.2 to 0.1 is a new best.\n",
    "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n",
    "cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]\n",
    "with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
    "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n",
    "test_eq(learn.test_tracker.news,  [True,True])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "#Setting reset_on_fit=False will maintain the \"best\" value over subsequent calls to fit\n",
    "learn = synth_learner(n_val=2, cbs=TrackerCallback(monitor='tst_metric', reset_on_fit=False), metrics=tst_metric)\n",
    "tracker_cb = learn.cbs.filter(lambda cb: isinstance(cb, TrackerCallback))[0]\n",
    "with learn.no_logging(): learn.fit(1)\n",
    "first_best = tracker_cb.best\n",
    "with learn.no_logging(): learn.fit(1)\n",
    "test_eq(tracker_cb.best, first_best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "#A tracker callback is not run during an lr_find\n",
    "from fastai.callback.schedule import *\n",
    "learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)\n",
    "learn.lr_find(num_it=15, show_plot=False)\n",
    "#`new_best` is only ever set on the callback itself (in `after_epoch`), so that is where its absence is checked\n",
    "assert not hasattr(learn.tracker, 'new_best')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EarlyStoppingCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class EarlyStoppingCallback(TrackerCallback):\n",
    "    \"A `TrackerCallback` that terminates training when monitored quantity stops improving.\"\n",
    "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, reset_on_fit=True):\n",
    "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)\n",
    "        # Number of consecutive epochs without a new best that are tolerated before stopping\n",
    "        self.patience = patience\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Reset the count of epochs without improvement, then prepare the tracker.\"\n",
    "        self.wait = 0\n",
    "        super().before_fit()\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Update the best score and cancel the fit after `patience` epochs in a row without a new best.\"\n",
    "        super().after_epoch()\n",
    "        if self.new_best:\n",
    "            self.wait = 0\n",
    "            return\n",
    "        self.wait += 1\n",
    "        if self.wait < self.patience: return\n",
    "        print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')\n",
    "        raise CancelFitException()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>mse_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>7.289688</td>\n",
       "      <td>6.762055</td>\n",
       "      <td>6.762055</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>7.300182</td>\n",
       "      <td>6.762038</td>\n",
       "      <td>6.762038</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>7.304251</td>\n",
       "      <td>6.762012</td>\n",
       "      <td>6.762012</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No improvement since epoch 0: early stopping\n"
     ]
    }
   ],
   "source": [
    "# With such a tiny learning rate `mse_loss` never improves by `min_delta`, so training stops after `patience` epochs without a new best\n",
    "learn = synth_learner(n_trn=2, metrics=F.mse_loss)\n",
    "learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(#2) [6.762012481689453,6.762012481689453]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>17.360760</td>\n",
       "      <td>18.424431</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>17.410801</td>\n",
       "      <td>18.424389</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>17.400934</td>\n",
       "      <td>18.424322</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No improvement since epoch 0: early stopping\n"
     ]
    }
   ],
   "source": [
    "# Same experiment, this time monitoring `valid_loss`\n",
    "learn = synth_learner(n_trn=2)\n",
    "learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Epoch 0 set the best value and the next 2 (`patience`) epochs did not beat it by `min_delta`: 3 epochs recorded\n",
    "test_eq(len(learn.recorder.values), 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SaveModelCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class SaveModelCallback(TrackerCallback):\n",
    "    \"A `TrackerCallback` that saves the model's best during training and loads it at the end.\"\n",
    "    _only_train_loop = True\n",
    "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False,\n",
    "                 with_opt=False, reset_on_fit=True):\n",
    "        # `fname`: name the model is saved under (suffixed with `_{epoch}` when `every_epoch` is set)\n",
    "        # `every_epoch`: save after every epoch instead of only when the monitored value improves\n",
    "        # `with_opt`: passed to `Learner.save`/`Learner.load`\n",
    "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)\n",
    "        # keep track of file path for loggers\n",
    "        self.last_saved_path = None\n",
    "        # whether this callback has written a checkpoint yet (there is nothing to reload until it has)\n",
    "        self._has_saved = False\n",
    "        store_attr('fname,every_epoch,with_opt')\n",
    "\n",
    "    def _save(self, name):\n",
    "        \"Save the model under `name` and remember what `Learner.save` returned\"\n",
    "        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)\n",
    "        self._has_saved = True\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Compare the value monitored to its best score and save if best.\"\n",
    "        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')\n",
    "        else: #every improvement\n",
    "            super().after_epoch()\n",
    "            if self.new_best:\n",
    "                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')\n",
    "                self._save(f'{self.fname}')\n",
    "\n",
    "    def after_fit(self, **kwargs):\n",
    "        \"Load the best model.\"\n",
    "        if self.every_epoch: return\n",
    "        # Skip the reload when no checkpoint was ever written (fit cancelled before the end of the first epoch,\n",
    "        # monitored value never an improvement): loading would fail or pick up a stale file from an earlier run\n",
    "        if not self._has_saved: return\n",
    "        self.learn.load(f'{self.fname}', with_opt=self.with_opt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by at least that amount. The model is saved in `learn.path/learn.model_dir/fname.pth` at each improvement of the monitored quantity, or in `learn.path/learn.model_dir/fname_{epoch}.pth` at the end of every epoch if `every_epoch=True`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>16.789867</td>\n",
       "      <td>15.243573</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>16.694817</td>\n",
       "      <td>14.915352</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better model found at epoch 0 with valid_loss value: 15.243573188781738.\n",
      "Better model found at epoch 1 with valid_loss value: 14.915351867675781.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>16.171261</td>\n",
       "      <td>14.440922</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>15.897676</td>\n",
       "      <td>13.870268</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Checkpoints are written under `tmp/models` in the current directory\n",
    "learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')\n",
    "# Default mode: a single `model.pth`, overwritten at every improvement of `valid_loss`\n",
    "learn.fit(n_epoch=2, cbs=SaveModelCallback())\n",
    "assert (Path.cwd()/'tmp/models/model.pth').exists()\n",
    "# `every_epoch=True`: one `model_{epoch}.pth` per epoch\n",
    "learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))\n",
    "for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()\n",
    "# Clean up the temporary directory\n",
    "shutil.rmtree(Path.cwd()/'tmp')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ReduceLROnPlateau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ReduceLROnPlateau(TrackerCallback):\n",
    "    \"A `TrackerCallback` that reduces learning rate when a metric has stopped improving.\"\n",
    "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0, reset_on_fit=True):\n",
    "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)\n",
    "        # `patience`: epochs without a new best before the lr is cut; `factor`: divisor applied to the lr;\n",
    "        # `min_lr`: floor under which the lr is never pushed\n",
    "        self.patience = patience\n",
    "        self.factor = factor\n",
    "        self.min_lr = min_lr\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Reset the count of epochs without improvement, then prepare the tracker.\"\n",
    "        self.wait = 0\n",
    "        super().before_fit()\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Update the best score and divide every lr by `factor` after `patience` epochs in a row without a new best.\"\n",
    "        super().after_epoch()\n",
    "        if self.new_best:\n",
    "            self.wait = 0\n",
    "            return\n",
    "        self.wait += 1\n",
    "        if self.wait < self.patience: return\n",
    "        previous_lr = self.opt.hypers[-1]['lr']\n",
    "        for group in self.opt.hypers: group['lr'] = max(group['lr'] / self.factor, self.min_lr)\n",
    "        # Start counting again from the reduced learning rate\n",
    "        self.wait = 0\n",
    "        new_lr = self.opt.hypers[-1]['lr']\n",
    "        # Stay silent when the lr was already sitting at `min_lr`\n",
    "        if new_lr < previous_lr: print(f'Epoch {self.epoch}: reducing lr to {new_lr}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>25.698519</td>\n",
       "      <td>40.559010</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>25.731672</td>\n",
       "      <td>40.558941</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>25.716520</td>\n",
       "      <td>40.558846</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>25.729078</td>\n",
       "      <td>40.558838</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2: reducing lr to 1e-08\n"
     ]
    }
   ],
   "source": [
    "# The tiny learning rate keeps `valid_loss` from improving by `min_delta`, which triggers a reduction after `patience` epochs\n",
    "learn = synth_learner(n_trn=2)\n",
    "learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#A single reduction by the default `factor` of 10 took the lr from 1e-7 to 1e-8\n",
    "test_eq(learn.opt.hypers[-1]['lr'], 1e-8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>6.026482</td>\n",
       "      <td>5.962046</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>6.027329</td>\n",
       "      <td>5.962038</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>6.024285</td>\n",
       "      <td>5.962028</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>6.024290</td>\n",
       "      <td>5.962025</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>6.022497</td>\n",
       "      <td>5.962022</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>6.023217</td>\n",
       "      <td>5.962019</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2: reducing lr to 1e-08\n"
     ]
    }
   ],
   "source": [
    "# With `min_lr` set, the learning rate is never reduced below that floor\n",
    "learn = synth_learner(n_trn=2)\n",
    "learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#5e-8 divided by 10 would be 5e-9, so the lr was clamped at `min_lr`\n",
    "test_eq(learn.opt.hypers[-1]['lr'], 1e-8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
